#ifndef GRADIENT_GENERATOR_H
#define GRADIENT_GENERATOR_H

#include <vector>

#include "Loss_generator.h"
#include "Indata_function.h"
#include <Windows.h>
#pragma warning(disable: n_label996)

class Gradient_generator : public Loss_generator {
	private:
		struct app1_gradient* gradient;
	public:
		Gradient_generator(problem, NN_data*, User_parameter*, struct app1_indata*, struct app1_solution*, struct app1_gradient*);
		void calc_grad_expected_demand(void);
		void calc_grad_ccon();
		void calc_grad_ccon_batch();
		void calc_grad_test_case_ccon();
		void clear(void);
		~Gradient_generator() {  }
};

/*
 * Forwards the problem / network data / user parameters / input / solution
 * handles to Loss_generator and stores the gradient output buffer.
 *
 * gradient: destination written by every calc_grad_* method. Only the pointer
 *           is kept; this class never frees it.
 *
 * Declared `inline` because this definition lives in a header: without it,
 * including the header from two translation units yields duplicate symbols.
 */
inline Gradient_generator::Gradient_generator(problem pr, NN_data* nn_data, User_parameter* user_parameter, struct app1_indata* indata, struct app1_solution* solution, struct app1_gradient* gradient)
	: Loss_generator(pr, nn_data, user_parameter, indata, solution), gradient(gradient) {
}

/*
 * Fills gradient->batch_expected_gradient[l][i] for every label l < n_label and
 * every sample i < nn_data->n_sample with
 *
 *   -pr.k[l] * (mainsol_batch[i][l][1] - mainsol_batch[i][l][0]) / user_parameter->lambda
 *   + 0.0001 * (batch_expected[l][i] - batch_ans[l][i])
 *
 * The second term is the derivative of 0.00005 * (expected - ans)^2.
 * NOTE(review): presumably a small penalty pulling the expected value toward
 * batch_ans - confirm intent.
 *
 * `inline` because this member is defined in a header (ODR).
 */
inline void Gradient_generator::calc_grad_expected_demand() {
	// Weight of the (expected - answer) correction term.
	const double answer_pull_weight = 0.0001;
	for (int l = 0; l < n_label; l++) {
		for (int i = 0; i < nn_data->n_sample; i++) {
			gradient->batch_expected_gradient[l][i] = -pr.k[l] * (solution->mainsol_batch[i][l][1] - solution->mainsol_batch[i][l][0]) / user_parameter->lambda;
			gradient->batch_expected_gradient[l][i] += answer_pull_weight * (nn_data->batch_expected[l][i] - nn_data->batch_ans[l][i]);
		}
	}
}




/*
 * Stochastic gradient of the softmax-weighted change in approximate batch loss
 * with respect to the class logits nn_data->ccon_w_batch[l][*][j].
 *
 * For each label l < n_label and sample j < nn_data->n_sample, the pair is
 * processed only when rand() % 10 == 0. For a processed pair:
 *   L_k = calc_approx_loss_batch(j, nc with nclass[l][j] = k, indata->p) - solution->loss[j]
 *         (L_k = 0 for the currently assigned class),
 *   p_k = exp(w_k) / sum_m exp(w_m),
 *   ccon_gradient_batch[l][i][j] = (p_i * L_i - sum_k L_k * p_k * p_i) / n_sample,
 * which is d/dw_i of sum_k p_k * L_k, divided by n_sample.
 *
 * Side effects: writes gradient->ccon_gradient_batch only for processed pairs;
 * entries of skipped pairs keep whatever they held before.
 * NOTE(review): those are stale unless clear() runs first - confirm callers do.
 * Consumes exactly one rand() per (l, j) pair.
 *
 * `inline` because this member is defined in a header (ODR).
 */
inline void Gradient_generator::calc_grad_ccon_batch() {
	// Scratch copy so candidate classes can be tried without touching indata->nc.
	normalclass nc = indata->nc;
	const int n_class = nn_data->n_class;
	// Sized from n_class (previously a fixed [10] array that overflowed for
	// larger class counts). Every entry is rewritten for each processed pair.
	std::vector<double> approx_loss_class(
		static_cast<std::vector<double>::size_type>(n_class > 0 ? n_class : 0), 0.0);

	for (int l = 0; l < n_label; l++) {
		for (int j = 0; j < nn_data->n_sample; j++) {
			// Subsampling: only pairs with rand() % 10 == 0 are processed.
			if (rand() % 10)
				continue;

			// Restored after every probe, so it is invariant across the loop below.
			const int assigned = nc.nclass[l][j];
			double expsum = 0;
			for (int i = 0; i < n_class; i++) {
				nc.nclass[l][j] = i;
				approx_loss_class[i] = (assigned == i) ? 0 : this->calc_approx_loss_batch(j, nc, indata->p) - this->solution->loss[j];
				expsum += exp(nn_data->ccon_w_batch[l][i][j]);
				nc.nclass[l][j] = assigned;
			}

			// Softmax-derivative chain rule; expressions kept in their original
			// evaluation order so results are bit-for-bit unchanged.
			for (int i = 0; i < n_class; i++) {
				gradient->ccon_gradient_batch[l][i][j] = 0;
				for (int k = 0; k < n_class; k++) {
					if (k == i)
						gradient->ccon_gradient_batch[l][i][j] += approx_loss_class[k] * (exp(nn_data->ccon_w_batch[l][k][j]) / expsum);
					gradient->ccon_gradient_batch[l][i][j] += -approx_loss_class[k] * (exp(nn_data->ccon_w_batch[l][k][j]) * exp(nn_data->ccon_w_batch[l][i][j]) / expsum / expsum);
				}
				gradient->ccon_gradient_batch[l][i][j] /= nn_data->n_sample;
			}
		}
	}
}

/*
 * Computes gradient->test_case_ccon_gradient[l][i][j]: the gradient with respect
 * to the logits nn_data->test_case_ccon_w[l][i][j]. It covers every label l,
 * logit i < n_class and internal test case j < n_internal_test_case.
 *
 * Phase 1 (probing), per label l. For every cell (i < n_rank, j < n_class) of
 * the count table nt.table[l][i][j][0..1], the counts are temporarily changed.
 * A posterior is rebuilt with calc_posterior(), and
 *   calc_approx_loss_batch(0, indata->nc, posterior) - solution->loss[0]
 * is recorded into:
 *   plusoneone        : [0]+1 and [1]+1
 *   plusone_minusxi   : [0]+1
 *   plusone_plusxi    : [1]+1
 *   minusone_plusxi   : [1]-1            (only if [1] is nonzero)
 *   minusone_minusxi  : [0]-1            (only if [0] is nonzero)
 *   minusoneone       : [0]-1 and [1]-1  (only if both are nonzero)
 * Skipped probes read back as 0. Every probe restores the working copy.
 * NOTE(review): the loss is always evaluated for sample index 0 - confirm.
 *
 * Phase 2 (gradient), per test case j. Let S = sum_m exp(w_m). For each class k,
 * define the softmax with logit k shifted by -xi and +xi:
 *   q_k^-(w) = exp(w_k - xi) / (S - exp(w_k) + exp(w_k - xi))
 *   q_k^+(w) = exp(w_k + xi) / (S - exp(w_k) + exp(w_k + xi))
 * The stored value is d/dw_i of sum_k [c_minus(k) * q_k^- + c_plus(k) * q_k^+].
 * The coefficients c_minus, c_plus are probe-loss combinations chosen by
 * nt.table_eff[l][k][j][0..1] (see the selection below). Probe tables are read
 * at row (ans[l][j] < 3 ? 1 : 0).
 * NOTE(review): rows >= 2 are computed in phase 1 but never read.
 *
 * `inline` because this member is defined in a header (ODR).
 */
inline void Gradient_generator::calc_grad_test_case_ccon() {
	// Working copy of the count table; every probe restores what it changes.
	normaltable nt = indata->nt;
	const int n_rank = user_parameter->n_rank;
	const int n_class = nn_data->n_class;
	// Probe tables are row-major [row * n_class + class]. Phase 2 reads rows 0
	// and 1 regardless of n_rank, so keep at least two zero-filled rows; this
	// matches the old fixed [5][10] arrays, which overflowed for n_rank > 5 or
	// n_class > 10.
	const int n_rows = n_rank > 2 ? n_rank : 2;
	const std::vector<double>::size_type table_size =
		static_cast<std::vector<double>::size_type>(n_rows) *
		static_cast<std::vector<double>::size_type>(n_class > 0 ? n_class : 0);

	std::vector<double> plusoneone_approxloss;
	std::vector<double> plusone_minusxi_approxloss;
	std::vector<double> plusone_plusxi_approxloss;
	std::vector<double> minusoneone_approxloss;
	std::vector<double> minusone_minusxi_approxloss;
	std::vector<double> minusone_plusxi_approxloss;

	for (int l = 0; l < n_label; l++) {
		// Reset per label: probes that are skipped must read back as 0.
		plusoneone_approxloss.assign(table_size, 0.0);
		plusone_minusxi_approxloss.assign(table_size, 0.0);
		plusone_plusxi_approxloss.assign(table_size, 0.0);
		minusoneone_approxloss.assign(table_size, 0.0);
		minusone_minusxi_approxloss.assign(table_size, 0.0);
		minusone_plusxi_approxloss.assign(table_size, 0.0);

		// Phase 1: loss change for +/-1 perturbations of each count-table cell.
		for (int i = 0; i < n_rank; i++) {
			for (int j = 0; j < n_class; j++) {
				const int idx = i * n_class + j;

				// [0]+1, [1]+1
				nt.table[l][i][j][0]++;
				nt.table[l][i][j][1]++;
				plusoneone_approxloss[idx] = this->calc_approx_loss_batch(0, indata->nc, calc_posterior(*nn_data, *user_parameter, indata, nt)) - this->solution->loss[0];

				// [0]+1
				nt.table[l][i][j][1]--;
				plusone_minusxi_approxloss[idx] = this->calc_approx_loss_batch(0, indata->nc, calc_posterior(*nn_data, *user_parameter, indata, nt)) - this->solution->loss[0];

				// [1]+1
				nt.table[l][i][j][0]--;
				nt.table[l][i][j][1]++;
				plusone_plusxi_approxloss[idx] = this->calc_approx_loss_batch(0, indata->nc, calc_posterior(*nn_data, *user_parameter, indata, nt)) - this->solution->loss[0];
				nt.table[l][i][j][1]--; // cell restored

				// [1]-1, only when it would not go below zero
				if (nt.table[l][i][j][1]) {
					nt.table[l][i][j][1]--;
					minusone_plusxi_approxloss[idx] = this->calc_approx_loss_batch(0, indata->nc, calc_posterior(*nn_data, *user_parameter, indata, nt)) - this->solution->loss[0];
					nt.table[l][i][j][1]++;
				}

				// [0]-1, only when it would not go below zero
				if (nt.table[l][i][j][0]) {
					nt.table[l][i][j][0]--;
					minusone_minusxi_approxloss[idx] = this->calc_approx_loss_batch(0, indata->nc, calc_posterior(*nn_data, *user_parameter, indata, nt)) - this->solution->loss[0];
					nt.table[l][i][j][0]++;
				}

				// [0]-1 and [1]-1, only when both are nonzero
				if (nt.table[l][i][j][0] && nt.table[l][i][j][1]) {
					nt.table[l][i][j][0]--;
					nt.table[l][i][j][1]--;
					minusoneone_approxloss[idx] = this->calc_approx_loss_batch(0, indata->nc, calc_posterior(*nn_data, *user_parameter, indata, nt)) - this->solution->loss[0];
					nt.table[l][i][j][0]++;
					nt.table[l][i][j][1]++;
				}
			}
		}

		// Phase 2: softmax-derivative accumulation per internal test case.
		for (int j = 0; j < user_parameter->n_internal_test_case; j++) {
			double expsum = 0;
			for (int i = 0; i < n_class; i++) {
				expsum += exp(nn_data->test_case_ccon_w[l][i][j]);
			}
			// Probe-table row used for this test case.
			const int row = (nn_data->ans[l][j] < 3) ? 1 : 0;

			for (int i = 0; i < n_class; i++) {
				double gradient_tmp = 0.;
				for (int k = 0; k < n_class; k++) {
					const double exp_k = exp(nn_data->test_case_ccon_w[l][k][j]);
					const double exp_k_minus = exp(nn_data->test_case_ccon_w[l][k][j] - user_parameter->xi);
					const double exp_k_plus = exp(nn_data->test_case_ccon_w[l][k][j] + user_parameter->xi);
					// Softmax denominators with logit k shifted by -xi / +xi.
					const double mxi_expsum = expsum + exp_k_minus - exp_k;
					const double pxi_expsum = expsum + exp_k_plus - exp_k;

					// Loss weights on the shifted softmax terms, selected by which
					// table_eff flags are set for (l, k, j).
					const int idx = row * n_class + k;
					double coef_minus;
					double coef_plus;
					if (nt.table_eff[l][k][j][0] && nt.table_eff[l][k][j][1]) {
						coef_minus = -minusone_minusxi_approxloss[idx];
						coef_plus = minusone_minusxi_approxloss[idx] - minusoneone_approxloss[idx];
					}
					else if (nt.table_eff[l][k][j][1]) {
						coef_minus = plusone_minusxi_approxloss[idx];
						coef_plus = -minusone_plusxi_approxloss[idx];
					}
					else {
						coef_minus = plusoneone_approxloss[idx] - plusone_plusxi_approxloss[idx];
						coef_plus = plusone_plusxi_approxloss[idx];
					}

					if (k == i) {
						// d q_k / d w_k = q_k - q_k^2 (the shifted logit is k itself).
						gradient_tmp += coef_minus * exp_k_minus / mxi_expsum;
						gradient_tmp += coef_plus * exp_k_plus / pxi_expsum;
						gradient_tmp += -coef_minus * exp_k_minus * exp_k_minus / mxi_expsum / mxi_expsum;
						gradient_tmp += -coef_plus * exp_k_plus * exp_k_plus / pxi_expsum / pxi_expsum;
					}
					else {
						// d q_k / d w_i = -q_k * exp(w_i) / denominator for i != k.
						const double exp_i = exp(nn_data->test_case_ccon_w[l][i][j]);
						gradient_tmp += -coef_minus * exp_k_minus * exp_i / mxi_expsum / mxi_expsum;
						gradient_tmp += -coef_plus * exp_k_plus * exp_i / pxi_expsum / pxi_expsum;
					}
				}
				gradient->test_case_ccon_gradient[l][i][j] = gradient_tmp;
			}
		}
	}
}

/*
 * Zeroes the two ccon gradient buffers, ccon_gradient_batch and
 * test_case_ccon_gradient. batch_expected_gradient is not touched.
 *
 * NOTE(review): sizeof(member) spans the whole buffer only if these members are
 * fixed-size arrays inside app1_gradient. If they are pointers, this zeroes only
 * sizeof(pointer) bytes of the pointee - confirm against the struct definition.
 *
 * `inline` because this member is defined in a header (ODR).
 */
inline void Gradient_generator::clear() {
	memset(gradient->ccon_gradient_batch, 0, sizeof(gradient->ccon_gradient_batch));
	memset(gradient->test_case_ccon_gradient, 0, sizeof(gradient->test_case_ccon_gradient));
}




#endif